{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a642a12f",
   "metadata": {},
   "source": [
    "# Matching General Use Case\n",
    "\n",
    "Matching is a widely used causal inference tool. The main idea is to improve statistical model performance by preprocessing the data with nonparametric matching methods. After preprocessing the data with CausalMatch, researchers can use whatever parametric model they would have used without CausalMatch, but produce inferences that are substantially more robust and less sensitive to modeling assumptions.\n",
    "\n",
    "### Data\n",
    "\n",
    "We have the following types of observations:\n",
    "* Covariates, which we will denote with `X`\n",
    "* Treatment, which we will denote with `T`\n",
    "* Responses, which we will denote with `Y`\n",
    "\n",
    "The treatment `T` must be a binary variable containing only 0/1 values.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b4cac6b0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['/Library/Frameworks/Python.framework/Versions/3.12/lib/python312.zip', '/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12', '/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/lib-dynload', '', '/Users/bytedance/Library/Python/3.12/lib/python/site-packages', '/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages']\n",
      "current version is:  0.0.5\n",
      "/usr/local/bin/python3.12\n",
      "3.12.2 (v3.12.2:6abddd9f6a, Feb  6 2024, 17:02:06) [Clang 13.0.0 (clang-1300.0.29.30)]\n",
      "sys.version_info(major=3, minor=12, micro=2, releaselevel='final', serial=0)\n"
     ]
    }
   ],
   "source": [
    "# Standard library\n",
    "import sys\n",
    "\n",
    "# Optional: uncomment to import causalmatch from the parent directory (e.g. a local checkout).\n",
    "# The path must be appended before causalmatch is imported below.\n",
    "# %load_ext autoreload\n",
    "# %autoreload 2\n",
    "# sys.path.append('..')\n",
    "\n",
    "# Third-party: data handling, regression and candidate p-score models\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import statsmodels.api as sm\n",
    "from lightgbm import LGBMClassifier\n",
    "from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import f1_score, recall_score, roc_auc_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.naive_bayes import GaussianNB, MultinomialNB\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from xgboost import XGBClassifier\n",
    "\n",
    "# Library under demonstration\n",
    "import causalmatch\n",
    "from causalmatch import gen_test_data, matching\n",
    "\n",
    "# Record the library version and interpreter so the run can be reproduced\n",
    "print('current version is: ', causalmatch.__version__)\n",
    "print(sys.executable)\n",
    "print(sys.version)\n",
    "print(sys.version_info)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c50ea99",
   "metadata": {},
   "source": [
    "## 1. Generate synthetic example data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8ccbd8c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulate the example dataset (n observations; c_ratio is passed straight through to gen_test_data)\n",
    "df, rand_continuous, rand_true_param, param_te, rand_treatment, rand_error = gen_test_data(n = 10000, c_ratio=0.5)\n",
    "\n",
    "# Column roles used throughout the notebook\n",
    "X = ['c_1', 'c_2', 'c_3', 'd_1', 'gender']  # covariates\n",
    "y = ['y', 'y2']                             # outcomes\n",
    "# NOTE: `id` shadows the Python builtin id(); the name is kept because later cells pass it to matching() and merge()\n",
    "id = 'user_id'                             # unit identifier\n",
    "\n",
    "# treatment variable has to be a 0/1 dummy variable\n",
    "# if it is a string, convert it to 0/1 integers first\n",
    "T = 'treatment'\n",
    "\n",
    "# Fail fast if the data does not match the roles declared above\n",
    "missing_cols = set(X + y + [id, T]) - set(df.columns)\n",
    "assert not missing_cols, f'missing columns: {sorted(missing_cols)}'\n",
    "assert set(df[T].unique()).issubset({0, 1}), 'treatment must be a 0/1 dummy'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed0d9035",
   "metadata": {},
   "source": [
    "## 2. PSM Demo\n",
    "###  2.1 Simple PSM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "62795d70",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processed Samples: 100%|████████████| 5097/5097 [00:00<00:00, 627778.45sample/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "******************************\n",
      "post-matching balance check reseult        Covariates  Mean Treated post-match  Mean Control post-match   SMD  \\\n",
      "0             c_1                   0.5052                   0.5145 -0.03   \n",
      "1             c_2                   0.5026                   0.4954  0.03   \n",
      "2             c_3                   0.5090                   0.5003  0.03   \n",
      "3         d_1_bee                   0.1123                   0.1097  0.01   \n",
      "4         d_1_cat                   0.1584                   0.1446  0.04   \n",
      "5         d_1_dog                   0.1934                   0.1828  0.03   \n",
      "6        d_1_pear                   0.2567                   0.2787 -0.05   \n",
      "7     gender_cat1                   0.0721                   0.0708  0.01   \n",
      "8     gender_cat2                   0.0865                   0.0861  0.00   \n",
      "9     gender_cat3                   0.1717                   0.2007 -0.08   \n",
      "10    gender_cat4                   0.0210                   0.0151  0.04   \n",
      "11    gender_cat5                   0.1154                   0.1094  0.02   \n",
      "12     gender_dog                   0.1807                   0.1765  0.01   \n",
      "13  gender_female                   0.2087                   0.2090 -0.00   \n",
      "14    gender_male                   0.0238                   0.0306 -0.04   \n",
      "15     gender_pig                   0.0693                   0.0642  0.02   \n",
      "\n",
      "    Var Ratio  ks-p_val  ttest-p_val  \n",
      "0        1.08     0.003        0.109  \n",
      "1        1.06     0.003        0.209  \n",
      "2        1.05     0.003        0.138  \n",
      "3         NaN     1.000        0.687  \n",
      "4         NaN     0.771        0.066  \n",
      "5         NaN     0.959        0.198  \n",
      "6         NaN     0.214        0.017  \n",
      "7         NaN     1.000        0.805  \n",
      "8         NaN     1.000        0.938  \n",
      "9         NaN     0.042        0.000  \n",
      "10        NaN     1.000        0.034  \n",
      "11        NaN     1.000        0.370  \n",
      "12        NaN     1.000        0.601  \n",
      "13        NaN     1.000        0.963  \n",
      "14        NaN     1.000        0.047  \n",
      "15        NaN     1.000        0.334  \n",
      "pre-matching balance check reseult        Covariates  Mean Treated pre-match  Mean Control pre-match   SMD  \\\n",
      "0             c_1                  0.5004                  0.4984  0.01   \n",
      "1             c_2                  0.5024                  0.5062 -0.01   \n",
      "2             c_3                  0.5033                  0.5053 -0.01   \n",
      "3         d_1_bee                  0.1124                  0.1201 -0.02   \n",
      "4         d_1_cat                  0.1581                  0.1577  0.00   \n",
      "5         d_1_dog                  0.1929                  0.1997 -0.02   \n",
      "6        d_1_pear                  0.2551                  0.2541  0.00   \n",
      "7     gender_cat1                  0.0714                  0.0671  0.02   \n",
      "8     gender_cat2                  0.0883                  0.0938 -0.02   \n",
      "9     gender_cat3                  0.1683                  0.1691 -0.00   \n",
      "10    gender_cat4                  0.0247                  0.0288 -0.03   \n",
      "11    gender_cat5                  0.1156                  0.1167 -0.00   \n",
      "12     gender_dog                  0.1776                  0.1756  0.01   \n",
      "13  gender_female                  0.2046                  0.1964  0.02   \n",
      "14    gender_male                  0.0263                  0.0249  0.01   \n",
      "15     gender_pig                  0.0736                  0.0793 -0.02   \n",
      "\n",
      "    Var Ratio  ks-p_val  ttest-p_val  \n",
      "0        0.99     0.406        0.722  \n",
      "1        0.99     0.736        0.516  \n",
      "2        0.99     0.904        0.736  \n",
      "3         NaN     0.998        0.229  \n",
      "4         NaN     1.000        0.948  \n",
      "5         NaN     1.000        0.391  \n",
      "6         NaN     1.000        0.916  \n",
      "7         NaN     1.000        0.396  \n",
      "8         NaN     1.000        0.336  \n",
      "9         NaN     1.000        0.921  \n",
      "10        NaN     1.000        0.211  \n",
      "11        NaN     1.000        0.863  \n",
      "12        NaN     1.000        0.798  \n",
      "13        NaN     0.995        0.305  \n",
      "14        NaN     1.000        0.656  \n",
      "15        NaN     1.000        0.278  \n",
      "sensitivity test result:     Wilcoxon-statistic  gamma  stat upper bound  stat_lower_bound  \\\n",
      "0           7814466.0    1.0        6496126.50        6496126.50   \n",
      "1           7814466.0    1.5        7795351.80        5196901.20   \n",
      "2           7814466.0    2.0        8661502.00        4330751.00   \n",
      "3           7814466.0    2.5        9280180.71        3712072.29   \n",
      "4           6805230.0    1.0        6496126.50        6496126.50   \n",
      "5           6805230.0    1.5        7795351.80        5196901.20   \n",
      "6           6805230.0    2.0        8661502.00        4330751.00   \n",
      "7           6805230.0    2.5        9280180.71        3712072.29   \n",
      "\n",
      "   z-score upper bound  z-score lower bound  p-val upper bound  \\\n",
      "0            12.548218            12.548218            0.00000   \n",
      "1             0.185684            25.428257            0.42635   \n",
      "2                  NaN                  NaN                NaN   \n",
      "3                  NaN                  NaN                NaN   \n",
      "4             2.942109             2.942109            0.00163   \n",
      "5                  NaN                  NaN                NaN   \n",
      "6                  NaN                  NaN                NaN   \n",
      "7                  NaN                  NaN                NaN   \n",
      "\n",
      "   p-val lower bound   y  \n",
      "0            0.00000   y  \n",
      "1            0.00000   y  \n",
      "2                NaN   y  \n",
      "3                NaN   y  \n",
      "4            0.00163  y2  \n",
      "5                NaN  y2  \n",
      "6                NaN  y2  \n",
      "7                NaN  y2  \n",
      "Output dataframe columns Index(['user_id', 'treatment', 'pscore'], dtype='object')\n",
      "                  0         1\n",
      "const      6.379187  0.034522\n",
      "treatment  0.513786  0.085681\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# -------------------------------------------------------- #\n",
    "# STEP 1: initialize matching object\n",
    "match_obj = matching(data = df,\n",
    "                     T = T,\n",
    "                     X = X,\n",
    "                     id = id)\n",
    "# -------------------------------------------------------- #\n",
    "\n",
    "# STEP 2: propensity score matching\n",
    "match_obj.psm(n_neighbors = 1,                      # number of neighbors\n",
    "              model = GradientBoostingClassifier(), # p-score model\n",
    "              trim_percentage = 0.1,                # trim X% of pscore: 0.1 trims the lowest 5% and the highest 5%\n",
    "              caliper = 0.1,                        # p-score diff must be smaller than or equal to the caliper value\n",
    "              verbose = True)\n",
    "\n",
    "# # if you wish to keep all matched observations\n",
    "# match_obj.psm(n_neighbors = 1, \n",
    "#               model = GradientBoostingClassifier(), \n",
    "#               trim_percentage = 0, \n",
    "#               caliper = 1) \n",
    "\n",
    "# -------------------------------------------------------- #\n",
    "\n",
    "# STEP 3: balance check after propensity score matching\n",
    "res_post, res_pre = match_obj.balance_check(include_discrete = True)\n",
    "print('post-matching balance check result', res_post)  # post-match balance check\n",
    "print('pre-matching balance check result', res_pre)    # pre-match balance check\n",
    "\n",
    "# -------------------------------------------------------- #\n",
    "\n",
    "# STEP 4: sensitivity test\n",
    "match_obj.y = y\n",
    "df_sensitivity_test = match_obj.sensitivity_test(gamma = [1,1.5,2,2.5])\n",
    "print('sensitivity test result: ', df_sensitivity_test)\n",
    "\n",
    "# -------------------------------------------------------- #\n",
    "\n",
    "# STEP 5: get result - pandas df; merge outcomes y and covariates X from the original data onto the matched sample\n",
    "# print('Matched pair dataFrame columns', match_obj.df_out_final_post_trim_pair.columns)\n",
    "print('Output dataframe columns', match_obj.df_out_final_post_trim.columns)\n",
    "# validate: each id must appear only once in df, so the merge cannot duplicate matched rows\n",
    "df_out = match_obj.df_out_final_post_trim.merge(df[y + X + [id]], how='left', on = id, validate = 'many_to_one')\n",
    "\n",
    "# -------------------------------------------------------- #\n",
    "\n",
    "# STEP 6: calculate average treatment effect on the matched sample\n",
    "# statsmodels OLS expects a 1-d endog, so fit one regression per outcome\n",
    "# (regressing all outcome columns at once gives no valid standard errors / summary)\n",
    "X_mat = sm.add_constant(df_out[T])\n",
    "y_mat = df_out[y]\n",
    "\n",
    "ate_params = {}\n",
    "for outcome in y:\n",
    "    model = sm.OLS(y_mat[outcome], X_mat)\n",
    "    results = model.fit()\n",
    "    ate_params[outcome] = results.params\n",
    "\n",
    "# one column per outcome: intercept ('const') and treatment coefficient\n",
    "print(pd.DataFrame(ate_params))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "33eeb568",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjAAAAGdCAYAAAAMm0nCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAlxUlEQVR4nO3dfXRU1b3/8U8eJ4EyE0Azk7RBo0UeFEWJxhFsRXKJQr1yS6ssUxbWSFpN7AXqAylPVZFoSpWCCIWiwG0st95VvRZtahoEqozBBmgpIGqJAnInwUJmeCh5PL8/XJyfA0FJmMlkh/drrbOWs88+s79ni8zHfc6cibEsyxIAAIBBYqNdAAAAQHsRYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxomPdgGR0traqgMHDqhXr16KiYmJdjkAAOAsWJalI0eOKD09XbGxZ15n6bYB5sCBA8rIyIh2GQAAoAP27dunr33ta2fc320DTK9evSR9NgFOpzPK1QAAgLMRDAaVkZFhf46fSbcNMCcvGzmdTgIMAACG+bLbP7iJFwAAGIcAAwAAjEOAAQAAxum298AAAGCqlpYWNTU1RbuMiIiLi1N8fPw5P+KEAAMAQBdy9OhR7d+/X5ZlRbuUiOnRo4fS0tKUmJjY4fcgwAAA0EW0tLRo//796tGjhy688MJu9yBWy7LU2NiogwcPqqamRv379//Ch9V9EQIMAABdRFNTkyzL0oUXXqjk5ORolxMRycnJSkhI0Mcff6zGxkYlJSV16H24iRcAgC6mu628nKqjqy4h7xGGOgAAADoVAQYAABiHe2AAAOjinql4v1PHm/pvl3XqeB3BCgwAADAOAQYAABiHAAMAAM7JwYMH5fF4NG/ePLtt06ZNSkxMVGVlZUTG5B4YAEbzrXjwtDZv/vwoVAKcvy688EI9//zzGjdunEaPHq0BAwZo4sSJKioq0qhRoyIyJgEGAACcszFjxmjy5MnKy8tTVlaWevbsqZKSkoiNxyUkAAAQFvPnz1dzc7NeeukllZWVyeFwRGwsAgwAAAiLf/zjHzpw4IBaW1v10UcfRXQsLiEBAIBz1tjYqO9973u68847NWDAAN17773avn27UlNTIzIeKzAAAOCczZgxQ4FAQAsXLtQjjzyiyy67TPfcc0/ExmMFBgCALq6rPxl3/fr1WrBggd588005nU5J0n/913/pqquu0pIlS3TfffeFfUwCDAAAOCc33XSTmpqaQtouvvhiBQKBiI3JJSQAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYByexAsAQFf3ZknnjjeyuHPH64B2r8Bs3LhRt912m9LT0xUTE6NXXnklZL9lWZo9e7bS0tKUnJysnJwcffDBByF9Dh06pLy8PDmdTqWkpCg/P19Hjx4N6fO3v/1NN954o5KSkpSRkaHS0tL2nx0AAOiW2h1gjh07pquuukqLFy9uc39paakWLlyopUuXqqqqSj179lRubq5OnDhh98nLy9OOHTtUUVGhtWvXauPGjSooKLD3B4NBjR49WhdddJGqq6v1s5/9TD/96U+1bNmyDpwiAACIpNWrV6tv375qaGgIaR83bpwmTpwYkTHbHWBuvfVWzZ07V//xH/9x2j7LsrRgwQLNnDlTt99+u6688kqtXr1aBw4csFdqdu3apfLycv3qV79Sdna2RowYoUWLFmnNmjU6cOCAJKmsrEyNjY16/vnndfnll2vChAn60Y9+pKeffvrczhYAAITdd7/7XbW0tOjVV1+12+rq6vTaa6/pnnvuiciYYb2Jt6amRn6/Xzk5OXaby+VSdna2fD6fJMnn8yklJUVZWVl2n5ycHMXGxqqqqsru841vfEOJiYl2n9zcXO3evVuHDx9uc+yGhgYFg8GQDQAARF5ycrLuuusuvfDCC3bbr3/9
a/Xr10833XRTRMYMa4Dx+/2SJLfbHdLudrvtfX6/X6mpqSH74+Pj1adPn5A+bb3H58c4VUlJiVwul71lZGSc+wkBAICzMnnyZL3xxhv65JNPJEkrV67U3XffrZiYmIiM122+Rl1cXKxAIGBv+/bti3ZJAACcN66++mpdddVVWr16taqrq7Vjxw7dfffdERsvrF+j9ng8kqTa2lqlpaXZ7bW1tRo6dKjdp66uLuS45uZmHTp0yD7e4/GotrY2pM/J1yf7nMrhcMjhcITlPAAAQPvde++9WrBggT755BPl5ORE9GpIWFdgMjMz5fF4VFlZabcFg0FVVVXJ6/VKkrxer+rr61VdXW33WbdunVpbW5WdnW332bhxo5qamuw+FRUVGjBggHr37h3OkgEAQJjcdddd2r9/v5YvXx6xm3dPavcKzNGjR/Xhhx/ar2tqarRt2zb16dNH/fr105QpUzR37lz1799fmZmZmjVrltLT0zVu3DhJ0qBBg3TLLbdo8uTJWrp0qZqamlRUVKQJEyYoPT1d0mcT8Oijjyo/P1+PPPKI/v73v+sXv/iFnnnmmfCcNYBuzbfiwZDX3vz5UaoEOL+4XC6NHz9er732mv25HyntDjB/+ctfNHLkSPv1tGnTJEmTJk3SypUr9fDDD+vYsWMqKChQfX29RowYofLyciUlJdnHlJWVqaioSKNGjVJsbKzGjx+vhQsX2vtdLpfeeOMNFRYWatiwYbrgggs0e/bskGfFAABw3jDgybgnffLJJ8rLy4v4bR0xlmVZER0hSoLBoFwulwKBgJxOZ7TLARAhp662tIUVGJjixIkTqqmpUWZmZsj/+Jvg8OHDWr9+vb7zne9o586dGjBgwBn7ftF5nu3nN7+FBAAAztnVV1+tw4cP66mnnvrC8BIuBBgAAHDOPvroo04dr9s8BwYAAJw/CDAAAMA4BBgAALqYbvr9Gls4zo8AAwBAFxEXFydJamxsjHIlkXX8+HFJUkJCQoffg5t4AQDoIuLj49WjRw8dPHhQCQkJio3tXusMlmXp+PHjqqurU0pKih3YOoIAAwBAFxETE6O0tDTV1NTo448/jnY5EZOSknLG3zY8WwQYAAC6kMTERPXv37/bXkZKSEg4p5WXkwgwAAB0MbGxscY9ibezda+LawAA4LxAgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4YQ8wLS0tmjVrljIzM5WcnKxLL71Ujz/+uCzLsvtYlqXZs2crLS1NycnJysnJ0QcffBDyPocOHVJeXp6cTqdSUlKUn5+vo0ePhrtcAABgoLAHmKeeekpLlizRs88+q127dumpp55SaWmpFi1aZPcpLS3VwoULtXTpUlVVValnz57Kzc3ViRMn7D55eXnasWOHKioqtHbtWm3cuFEFBQXhLhcAABgoPtxvuGnTJt1+++0aO3asJOniiy/Wb37zG23evFnSZ6svCxYs0MyZM3X77bdLklavXi23261XXnlFEyZM0K5du1ReXq53331XWVlZkqRFixZpzJgxmj9/vtLT08NdNgAAMEjYV2BuuOEGVVZW6v3335ck/fWvf9Vbb72lW2+9VZJUU1Mjv9+vnJwc+xiXy6Xs7Gz5fD5Jks/nU0pKih1eJCknJ0exsbGqqqpqc9yGhgYFg8GQDQAAdE9hX4GZPn26gsGgBg4cqLi4
OLW0tOiJJ55QXl6eJMnv90uS3G53yHFut9ve5/f7lZqaGlpofLz69Olj9zlVSUmJHn300XCfDgAA6ILCvgLz29/+VmVlZXrxxRe1ZcsWrVq1SvPnz9eqVavCPVSI4uJiBQIBe9u3b19ExwMAANET9hWYhx56SNOnT9eECRMkSUOGDNHHH3+skpISTZo0SR6PR5JUW1urtLQ0+7ja2loNHTpUkuTxeFRXVxfyvs3NzTp06JB9/KkcDoccDke4TwcAAHRBYV+BOX78uGJjQ982Li5Ora2tkqTMzEx5PB5VVlba+4PBoKqqquT1eiVJXq9X9fX1qq6utvusW7dOra2tys7ODnfJAADAMGFfgbntttv0xBNPqF+/frr88su1detWPf3007rnnnskSTExMZoyZYrmzp2r/v37KzMzU7NmzVJ6errGjRsnSRo0aJBuueUWTZ48WUuXLlVTU5OKioo0YcIEvoEEAADCH2AWLVqkWbNm6f7771ddXZ3S09P1gx/8QLNnz7b7PPzwwzp27JgKCgpUX1+vESNGqLy8XElJSXafsrIyFRUVadSoUYqNjdX48eO1cOHCcJcLAAAMFGN9/hG53UgwGJTL5VIgEJDT6Yx2OQAixLfiwS/t482f3wmVAAiHs/385reQAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADBOfLQLAID28K14MNolAOgCWIEBAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA40QkwHzyySf63ve+p759+yo5OVlDhgzRX/7yF3u/ZVmaPXu20tLSlJycrJycHH3wwQch73Ho0CHl5eXJ6XQqJSVF+fn5Onr0aCTKBQAAhgl7gDl8+LCGDx+uhIQE/eEPf9DOnTv185//XL1797b7lJaWauHChVq6dKmqqqrUs2dP5ebm6sSJE3afvLw87dixQxUVFVq7dq02btyogoKCcJcLAAAMFGNZlhXON5w+fbrefvtt/fnPf25zv2VZSk9P149//GM9+OCDkqRAICC3262VK1dqwoQJ2rVrlwYPHqx3331XWVlZkqTy8nKNGTNG+/fvV3p6+pfWEQwG5XK5FAgE5HQ6w3eCAKLKt+LBdh/jzZ8fgUoARMLZfn6HfQXm1VdfVVZWlr773e8qNTVVV199tZYvX27vr6mpkd/vV05Ojt3mcrmUnZ0tn88nSfL5fEpJSbHDiyTl5OQoNjZWVVVVbY7b0NCgYDAYsgEAgO4p7AFmz549WrJkifr3768//vGPuu+++/SjH/1Iq1atkiT5/X5JktvtDjnO7Xbb+/x+v1JTU0P2x8fHq0+fPnafU5WUlMjlctlbRkZGuE8NAAB0EWEPMK2trbrmmms0b948XX311SooKNDkyZO1dOnScA8Vori4WIFAwN727dsX0fEAAED0hD3ApKWlafDgwSFtgwYN0t69eyVJHo9HklRbWxvSp7a21t7n8XhUV1cXsr+5uVmHDh2y+5zK4XDI6XSGbAAAoHsKe4AZPny4du/eHdL2/vvv66KLLpIkZWZmyuPxqLKy0t4fDAZVVVUlr9crSfJ6vaqvr1d1dbXdZ926dWptbVV2dna4SwYAAIaJD/cbTp06VTfccIPmzZunO+64Q5s3b9ayZcu0bNkySVJMTIymTJmiuXPn
qn///srMzNSsWbOUnp6ucePGSfpsxeaWW26xLz01NTWpqKhIEyZMOKtvIAEAgO4t7AHm2muv1csvv6zi4mI99thjyszM1IIFC5SXl2f3efjhh3Xs2DEVFBSovr5eI0aMUHl5uZKSkuw+ZWVlKioq0qhRoxQbG6vx48dr4cKF4S4XAAAYKOzPgekqeA4M0D3xHBige4vac2AAAAAijQADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDjx0S4AACLuzZLT20YWd34dAMKGAAOg2/Pt+edpbd6RUSgEQNhwCQkAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgnIgHmCeffFIxMTGaMmWK3XbixAkVFhaqb9+++spXvqLx48ertrY25Li9e/dq7Nix6tGjh1JTU/XQQw+pubk50uUCAAADRDTAvPvuu/rlL3+pK6+8MqR96tSp+v3vf6+XXnpJGzZs0IEDB/Ttb3/b3t/S0qKxY8eqsbFRmzZt0qpVq7Ry5UrNnj07kuUCAABDRCzAHD16VHl5eVq+fLl69+5ttwcCAa1YsUJPP/20br75Zg0bNkwvvPCCNm3apHfeeUeS9MYbb2jnzp369a9/raFDh+rWW2/V448/rsWLF6uxsTFSJQMAAENELMAUFhZq7NixysnJCWmvrq5WU1NTSPvAgQPVr18/+Xw+SZLP59OQIUPkdrvtPrm5uQoGg9qxY0eb4zU0NCgYDIZsAACge4qPxJuuWbNGW7Zs0bvvvnvaPr/fr8TERKWkpIS0u91u+f1+u8/nw8vJ/Sf3taWkpESPPvpoGKoHAABdXdhXYPbt26f//M//VFlZmZKSksL99mdUXFysQCBgb/v27eu0sQEAQOcKe4Cprq5WXV2drrnmGsXHxys+Pl4bNmzQwoULFR8fL7fbrcbGRtXX14ccV1tbK4/HI0nyeDynfSvp5OuTfU7lcDjkdDpDNgAA0D2FPcCMGjVK27dv17Zt2+wtKytLeXl59j8nJCSosrLSPmb37t3au3evvF6vJMnr9Wr79u2qq6uz+1RUVMjpdGrw4MHhLhkAABgm7PfA9OrVS1dccUVIW8+ePdW3b1+7PT8/X9OmTVOfPn3kdDr1wAMPyOv16vrrr5ckjR49WoMHD9bEiRNVWloqv9+vmTNnqrCwUA6HI9wlAwAAw0TkJt4v88wzzyg2Nlbjx49XQ0ODcnNz9dxzz9n74+LitHbtWt13333yer3q2bOnJk2apMceeywa5QIAgC4mxrIsK9pFREIwGJTL5VIgEOB+GKAb8a14MCzv482fH5b3ARBeZ/v5zW8hAQAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADBOfLQLAIAzCdcvTwPofliBAQAAxmEFBsB56dTVHW/+/ChVAqAjWIEBAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADj
EGAAAIBxCDAAAMA48dEuAAC6hDdLTm8bWdz5dQA4KwQYAJDk2/PP09q8I6NQCICzwiUkAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGCcsAeYkpISXXvtterVq5dSU1M1btw47d69O6TPiRMnVFhYqL59++orX/mKxo8fr9ra2pA+e/fu1dixY9WjRw+lpqbqoYceUnNzc7jLBQAABgp7gNmwYYMKCwv1zjvvqKKiQk1NTRo9erSOHTtm95k6dap+//vf66WXXtKGDRt04MABffvb37b3t7S0aOzYsWpsbNSmTZu0atUqrVy5UrNnzw53uQAAwEAxlmVZkRzg4MGDSk1N1YYNG/SNb3xDgUBAF154oV588UV95zvfkSS99957GjRokHw+n66//nr94Q9/0Le+9S0dOHBAbrdbkrR06VI98sgjOnjwoBITE7903GAwKJfLpUAgIKfTGclTBBAhvhUPRnV8b/78qI4PnI/O9vM74vfABAIBSVKfPn0kSdXV1WpqalJOTo7dZ+DAgerXr598Pp8kyefzaciQIXZ4kaTc3FwFg0Ht2LGjzXEaGhoUDAZDNgAA0D1FNMC0trZqypQpGj58uK644gpJkt/vV2JiolJSUkL6ut1u+f1+u8/nw8vJ/Sf3taWkpEQul8veMjIywnw2AACgq4hogCksLNTf//53rVmzJpLDSJKKi4sVCATsbd++fREfEwAAREd8pN64qKhIa9eu1caNG/W1r33Nbvd4PGpsbFR9fX3IKkxtba08Ho/dZ/PmzSHvd/JbSif7nMrhcMjhcIT5LAAAQFcU9hUYy7JUVFSkl19+WevWrVNmZmbI/mHDhikhIUGVlZV22+7du7V37155vV5Jktfr1fbt21VXV2f3qaiokNPp1ODBg8NdMgAAMEzYV2AKCwv14osv6n//93/Vq1cv+54Vl8ul5ORkuVwu5efna9q0aerTp4+cTqceeOABeb1eXX/99ZKk0aNHa/DgwZo4caJKS0vl9/s1c+ZMFRYWssoCAADCH2CWLFkiSbrppptC2l944QXdfffdkqRnnnlGsbGxGj9+vBoaGpSbm6vnnnvO7hsXF6e1a9fqvvvuk9frVc+ePTVp0iQ99thj4S4XAAAYKOLPgYkWngMDmI/nwADnny7zHBgAAIBwI8AAAADjEGAAAIBxCDAAAMA4EXuQHQC0R7Rv2AVgFlZgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcfgsJAM7g1N9n8ubPj1IlAE7FCgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHH4GjWAqDj1K8oA0B6swAAAAOOwAgMAZ6mtVSMebgdEByswAADAOAQYAABgHC4hAYAp3iw5vW1kcefXAXQBrMAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIzDc2AA4Bw8U/H+aW1T/+2yKFQCnF8IMABwDq7fu6yNVn4fCYg0AgwAhNupT8zlablA2BFgACDMfHv+GfLaOzJKhQDdGDfxAgAA47ACAwCRxo8wAmHHCgwAADAOKzAAIs634sFolxBVp94TI0nvNId+/ZqvXgPtwwoMAAAwDgEGAAAYh0tIABAFpz8Aj4ffAe3BCgwAADAOAQYAABiHS0gA0AXwo5BA+7ACAwAAjEOAAQAAxuESEoBzcr4/pC6i2voJAgCSCDAA0CWc/rVqSZf07fxCAEMQYACgi2rrJwhO5R355e/DDcLojrr0PTCLFy/WxRdfrKSkJGVnZ2vz5s3RLgkAupRnKt4P2YDzRZcNMP/93/+tadOmac6cOdqyZYuuuuoq5ebmqq6uLtqlAQCAKIuxLMuKdhFtyc7O1rXXXqtnn31WktTa2qqMjAw98MADmj59+pceHwwG5XK5FAgE
5HQ6I10ucN7gpt2u5Z1+BR06LlKXkM5mFYjLV/giZ/v53SXvgWlsbFR1dbWKi4vtttjYWOXk5Mjn87V5TENDgxoaGuzXgUBA0mcTAeDsbF49I9oloJ1OHDvaoeNKXtkS5krOXnf6e3nxug9DXhfe/PUoVdJ9nPzz8WXrK10ywHz66adqaWmR2+0OaXe73XrvvffaPKakpESPPvroae0ZGRkRqREAuoZno11Au/0k2gVEUHc+t8525MgRuVyuM+7vkgGmI4qLizVt2jT7dWtrqw4dOqS+ffsqJiYmbOMEg0FlZGRo3759XJqKMOa6czDPnYN57hzMc+eI5DxblqUjR44oPT39C/t1yQBzwQUXKC4uTrW1tSHttbW18ng8bR7jcDjkcDhC2lJSUiJVopxOJ/9xdBLmunMwz52Dee4czHPniNQ8f9HKy0ld8ltIiYmJGjZsmCorK+221tZWVVZWyuv1RrEyAADQFXTJFRhJmjZtmiZNmqSsrCxdd911WrBggY4dO6bvf//70S4NAABEWZcNMHfeeacOHjyo2bNny+/3a+jQoSovLz/txt7O5nA4NGfOnNMuVyH8mOvOwTx3Dua5czDPnaMrzHOXfQ4MAADAmXTJe2AAAAC+CAEGAAAYhwADAACMQ4ABAADGIcC0YfHixbr44ouVlJSk7Oxsbd68+Qv7v/TSSxo4cKCSkpI0ZMgQvf76651UqfnaM9fLly/XjTfeqN69e6t3797Kycn50n83+Ex7/0yftGbNGsXExGjcuHGRLbCbaO8819fXq7CwUGlpaXI4HLrsssv4++MstHeeFyxYoAEDBig5OVkZGRmaOnWqTpw40UnVmmnjxo267bbblJ6erpiYGL3yyitfesz69et1zTXXyOFw6Otf/7pWrlwZ2SIthFizZo2VmJhoPf/889aOHTusyZMnWykpKVZtbW2b/d9++20rLi7OKi0ttXbu3GnNnDnTSkhIsLZv397JlZunvXN91113WYsXL7a2bt1q7dq1y7r77rstl8tl7d+/v5MrN0t75/mkmpoa66tf/ap14403WrfffnvnFGuw9s5zQ0ODlZWVZY0ZM8Z66623rJqaGmv9+vXWtm3bOrlys7R3nsvKyiyHw2GVlZVZNTU11h//+EcrLS3Nmjp1aidXbpbXX3/dmjFjhvW73/3OkmS9/PLLX9h/z549Vo8ePaxp06ZZO3futBYtWmTFxcVZ5eXlEauRAHOK6667ziosLLRft7S0WOnp6VZJSUmb/e+44w5r7NixIW3Z2dnWD37wg4jW2R20d65P1dzcbPXq1ctatWpVpErsFjoyz83NzdYNN9xg/epXv7ImTZpEgDkL7Z3nJUuWWJdcconV2NjYWSV2C+2d58LCQuvmm28OaZs2bZo1fPjwiNbZnZxNgHn44Yetyy+/PKTtzjvvtHJzcyNWF5eQPqexsVHV1dXKycmx22JjY5WTkyOfz9fmMT6fL6S/JOXm5p6xPz7Tkbk+1fHjx9XU1KQ+ffpEqkzjdXSeH3vsMaWmpio/P78zyjReR+b51VdfldfrVWFhodxut6644grNmzdPLS0tnVW2cToyzzfccIOqq6vty0x79uzR66+/rjFjxnRKzeeLaHwWdtkn8UbDp59+qpaWltOe9ut2u/Xee++1eYzf72+zv9/vj1id3UFH5vpUjzzyiNLT00/7jwb/X0fm+a233tKKFSu0bdu2Tqiwe+jIPO/Zs0fr1q1TXl6eXn/9dX344Ye6//771dTUpDlz5nRG2cbpyDzfdddd+vTTTzVixAhZlqXm5mb98Ic/1E9+8pPOKPm8cabPwmAwqH/9619KTk4O+5iswMBITz75pNasWaOXX35ZSUlJ0S6n2zhy5IgmTpyo5cuX64ILLoh2Od1aa2urUlNTtWzZMg0bNkx33nmnZsyYoaVLl0a7tG5l/fr1mjdvnp577jlt2bJFv/vd7/Taa6/p8ccfj3ZpOEeswHzOBRdcoLi4ONXW
1oa019bWyuPxtHmMx+NpV398piNzfdL8+fP15JNP6k9/+pOuvPLKSJZpvPbO8z/+8Q999NFHuu222+y21tZWSVJ8fLx2796tSy+9NLJFG6gjf57T0tKUkJCguLg4u23QoEHy+/1qbGxUYmJiRGs2UUfmedasWZo4caLuvfdeSdKQIUN07NgxFRQUaMaMGYqN5f/jw+FMn4VOpzMiqy8SKzAhEhMTNWzYMFVWVtptra2tqqyslNfrbfMYr9cb0l+SKioqztgfn+nIXEtSaWmpHn/8cZWXlysrK6szSjVae+d54MCB2r59u7Zt22Zv//7v/66RI0dq27ZtysjI6MzyjdGRP8/Dhw/Xhx9+aAdESXr//feVlpZGeDmDjszz8ePHTwspJ0OjxU8Bhk1UPgsjdnuwodasWWM5HA5r5cqV1s6dO62CggIrJSXF8vv9lmVZ1sSJE63p06fb/d9++20rPj7emj9/vrVr1y5rzpw5fI36LLV3rp988kkrMTHR+p//+R/r//7v/+ztyJEj0ToFI7R3nk/Ft5DOTnvnee/evVavXr2soqIia/fu3dbatWut1NRUa+7cudE6BSO0d57nzJlj9erVy/rNb35j7dmzx3rjjTesSy+91LrjjjuidQpGOHLkiLV161Zr69atliTr6aeftrZu3Wp9/PHHlmVZ1vTp062JEyfa/U9+jfqhhx6ydu3aZS1evJivUUfDokWLrH79+lmJiYnWddddZ73zzjv2vm9+85vWpEmTQvr/9re/tS677DIrMTHRuvzyy63XXnutkys2V3vm+qKLLrIknbbNmTOn8ws3THv/TH8eAebstXeeN23aZGVnZ1sOh8O65JJLrCeeeMJqbm7u5KrN0555bmpqsn76059al156qZWUlGRlZGRY999/v3X48OHOL9wgb775Zpt/356c20mTJlnf/OY3Tztm6NChVmJionXJJZdYL7zwQkRrjLEs1tAAAIBZuAcGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOP8P1Ab5hsyrdYoAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot\n",
    "\n",
    "# Overlay the treated and control propensity-score distributions from\n",
    "# match_obj.df_out_final to visually check their overlap.\n",
    "bins = np.linspace(0, 1, 100)\n",
    "\n",
    "fig, ax = pyplot.subplots()\n",
    "ax.hist(match_obj.df_out_final['pscore_treat'].values, bins, alpha=0.5, label='treated')\n",
    "ax.hist(match_obj.df_out_final['pscore_control'].values, bins, alpha=0.5, label='control')\n",
    "ax.set(title='Propensity score distribution', xlabel='propensity score', ylabel='count')\n",
    "ax.legend(loc='upper right')\n",
    "pyplot.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a234cf19",
   "metadata": {},
   "source": [
    "### 2.2 PSM with multiple p-score models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8a96946e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[LightGBM] [Info] Number of positive: 4077, number of negative: 3923\n",
      "[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000279 seconds.\n",
      "You can set `force_col_wise=true` to remove the overhead.\n",
      "[LightGBM] [Info] Total Bins 791\n",
      "[LightGBM] [Info] Number of data points in the train set: 8000, number of used features: 16\n",
      "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.509625 -> initscore=0.038505\n",
      "[LightGBM] [Info] Start training from score 0.038505\n",
      "The f1 score for all models you specify is: [0.5773809523809523, 0.5805354866128347, 0.6130055511498811, 0.5143134400776322, 0.5076242006886375, 0.5004793863854267, 0.5656050955414013, 0.5198863636363636, 0.5149928605425987]\n",
      "The best model is the 2 model\n",
      "(       Covariates  Mean Treated post-match  Mean Control post-match   SMD  \\\n",
      "0             c_1                   0.5004                   0.4999  0.00   \n",
      "1             c_2                   0.5024                   0.5040 -0.01   \n",
      "2             c_3                   0.5033                   0.5033  0.00   \n",
      "3         d_1_bee                   0.1124                   0.1122  0.00   \n",
      "4         d_1_cat                   0.1581                   0.1579  0.00   \n",
      "5         d_1_dog                   0.1929                   0.1868  0.02   \n",
      "6        d_1_pear                   0.2551                   0.2670 -0.03   \n",
      "7     gender_cat1                   0.0714                   0.0724 -0.00   \n",
      "8     gender_cat2                   0.0883                   0.0865  0.01   \n",
      "9     gender_cat3                   0.1683                   0.1664  0.01   \n",
      "10    gender_cat4                   0.0247                   0.0247  0.00   \n",
      "11    gender_cat5                   0.1156                   0.1165 -0.00   \n",
      "12     gender_dog                   0.1776                   0.1848 -0.02   \n",
      "13  gender_female                   0.2046                   0.1982  0.02   \n",
      "14    gender_male                   0.0263                   0.0263  0.00   \n",
      "15     gender_pig                   0.0736                   0.0757 -0.01   \n",
      "\n",
      "    Var Ratio  ks-p_val  ttest-p_val  \n",
      "0        1.03     0.176        0.928  \n",
      "1        0.98     0.684        0.783  \n",
      "2        0.98     0.667        1.000  \n",
      "3         NaN     1.000        0.975  \n",
      "4         NaN     1.000        0.978  \n",
      "5         NaN     1.000        0.434  \n",
      "6         NaN     0.854        0.169  \n",
      "7         NaN     1.000        0.848  \n",
      "8         NaN     1.000        0.752  \n",
      "9         NaN     1.000        0.791  \n",
      "10        NaN     1.000        1.000  \n",
      "11        NaN     1.000        0.877  \n",
      "12        NaN     0.999        0.341  \n",
      "13        NaN     1.000        0.415  \n",
      "14        NaN     1.000        1.000  \n",
      "15        NaN     1.000        0.679  ,        Covariates  Mean Treated pre-match  Mean Control pre-match   SMD  \\\n",
      "0             c_1                  0.5004                  0.4984  0.01   \n",
      "1             c_2                  0.5024                  0.5062 -0.01   \n",
      "2             c_3                  0.5033                  0.5053 -0.01   \n",
      "3         d_1_bee                  0.1124                  0.1201 -0.02   \n",
      "4         d_1_cat                  0.1581                  0.1577  0.00   \n",
      "5         d_1_dog                  0.1929                  0.1997 -0.02   \n",
      "6        d_1_pear                  0.2551                  0.2541  0.00   \n",
      "7     gender_cat1                  0.0714                  0.0671  0.02   \n",
      "8     gender_cat2                  0.0883                  0.0938 -0.02   \n",
      "9     gender_cat3                  0.1683                  0.1691 -0.00   \n",
      "10    gender_cat4                  0.0247                  0.0288 -0.03   \n",
      "11    gender_cat5                  0.1156                  0.1167 -0.00   \n",
      "12     gender_dog                  0.1776                  0.1756  0.01   \n",
      "13  gender_female                  0.2046                  0.1964  0.02   \n",
      "14    gender_male                  0.0263                  0.0249  0.01   \n",
      "15     gender_pig                  0.0736                  0.0793 -0.02   \n",
      "\n",
      "    Var Ratio  ks-p_val  ttest-p_val  \n",
      "0        0.99     0.406        0.722  \n",
      "1        0.99     0.736        0.516  \n",
      "2        0.99     0.904        0.736  \n",
      "3         NaN     0.998        0.229  \n",
      "4         NaN     1.000        0.948  \n",
      "5         NaN     1.000        0.391  \n",
      "6         NaN     1.000        0.916  \n",
      "7         NaN     1.000        0.396  \n",
      "8         NaN     1.000        0.336  \n",
      "9         NaN     1.000        0.921  \n",
      "10        NaN     1.000        0.211  \n",
      "11        NaN     1.000        0.863  \n",
      "12        NaN     1.000        0.798  \n",
      "13        NaN     0.995        0.305  \n",
      "14        NaN     1.000        0.656  \n",
      "15        NaN     1.000        0.278  )\n",
      "Output dataframe columns Index(['user_id', 'treatment', 'pscore'], dtype='object')\n",
      "                  0         1\n",
      "const      6.340721  0.008059\n",
      "treatment  0.498932  0.103476\n"
     ]
    }
   ],
   "source": [
    "# STEP 0: define every propensity-score (p-score) model you want to try;\n",
    "# CausalMatch keeps the candidate with the best F1 score\n",
    "ps_model1 = LogisticRegression(C=1e6)\n",
    "ps_model2 = SVC(probability=True)\n",
    "ps_model3 = GaussianNB()\n",
    "ps_model4 = KNeighborsClassifier()\n",
    "ps_model5 = DecisionTreeClassifier()\n",
    "ps_model6 = RandomForestClassifier()\n",
    "ps_model7 = GradientBoostingClassifier()\n",
    "ps_model8 = LGBMClassifier()\n",
    "ps_model9 = XGBClassifier()\n",
    "\n",
    "model_list = [ps_model1, ps_model2, ps_model3,\n",
    "              ps_model4, ps_model5, ps_model6,\n",
    "              ps_model7, ps_model8, ps_model9]\n",
    "\n",
    "# STEP 1: initialize matching object\n",
    "match_obj = matching(data = df,\n",
    "                     T = T,\n",
    "                     X = X,\n",
    "                     id = id)\n",
    "\n",
    "# STEP 2: propensity score matching\n",
    "match_obj.psm(n_neighbors = 1,\n",
    "              model_list = model_list, # list of candidate p-score models\n",
    "              trim_percentage = 0,\n",
    "              caliper = 1,\n",
    "              test_size = 0.2) # share of the sample held out as the test set in the train-test split\n",
    "\n",
    "# STEP 3: balance check after propensity score matching\n",
    "print(match_obj.balance_check(include_discrete = True))\n",
    "\n",
    "# STEP 4: estimate the treatment effect with an OLS of y on treatment in the matched sample\n",
    "# -- match_obj.df_out_final: post-matching dataframe WITHOUT trimming, stored in pairs\n",
    "# -- match_obj.df_out_final_post_trim: post-matching dataframe after trimming observations by caliper and p-score percentage\n",
    "print('Output dataframe columns', match_obj.df_out_final_post_trim.columns)\n",
    "df_out = match_obj.df_out_final_post_trim.merge(df[y + X + [id]], how='left', on = id)\n",
    "\n",
    "X_mat = sm.add_constant(df_out[T])\n",
    "y_mat = df_out[y]\n",
    "\n",
    "results = sm.OLS(y_mat, X_mat).fit()\n",
    "print(results.params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "314d16be",
   "metadata": {},
   "source": [
    "### 2.3 PSM with built-in ATE estimation\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "76692a8a",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(       Covariates  Mean Treated post-match  Mean Control post-match   SMD  \\\n",
      "0             c_1                   0.5004                   0.4984  0.01   \n",
      "1             c_2                   0.5024                   0.5027 -0.00   \n",
      "2             c_3                   0.5033                   0.5031  0.00   \n",
      "3         d_1_bee                   0.1124                   0.1034  0.03   \n",
      "4         d_1_cat                   0.1581                   0.1727 -0.04   \n",
      "5         d_1_dog                   0.1929                   0.1878  0.01   \n",
      "6        d_1_pear                   0.2551                   0.2466  0.02   \n",
      "7     gender_cat1                   0.0714                   0.0667  0.02   \n",
      "8     gender_cat2                   0.0883                   0.0963 -0.03   \n",
      "9     gender_cat3                   0.1683                   0.1752 -0.02   \n",
      "10    gender_cat4                   0.0247                   0.0290 -0.03   \n",
      "11    gender_cat5                   0.1156                   0.1279 -0.04   \n",
      "12     gender_dog                   0.1776                   0.1695  0.02   \n",
      "13  gender_female                   0.2046                   0.1923  0.03   \n",
      "14    gender_male                   0.0263                   0.0241  0.01   \n",
      "15     gender_pig                   0.0736                   0.0783 -0.02   \n",
      "\n",
      "    Var Ratio  ks-p_val  ttest-p_val  \n",
      "0        1.00     0.058        0.724  \n",
      "1        1.01     0.504        0.967  \n",
      "2        0.99     0.766        0.965  \n",
      "3         NaN     0.984        0.142  \n",
      "4         NaN     0.651        0.049  \n",
      "5         NaN     1.000        0.512  \n",
      "6         NaN     0.993        0.326  \n",
      "7         NaN     1.000        0.349  \n",
      "8         NaN     0.996        0.161  \n",
      "9         NaN     1.000        0.358  \n",
      "10        NaN     1.000        0.178  \n",
      "11        NaN     0.826        0.056  \n",
      "12        NaN     0.996        0.284  \n",
      "13        NaN     0.826        0.118  \n",
      "14        NaN     1.000        0.487  \n",
      "15        NaN     1.000        0.370  ,        Covariates  Mean Treated pre-match  Mean Control pre-match   SMD  \\\n",
      "0             c_1                  0.5004                  0.4984  0.01   \n",
      "1             c_2                  0.5024                  0.5062 -0.01   \n",
      "2             c_3                  0.5033                  0.5053 -0.01   \n",
      "3         d_1_bee                  0.1124                  0.1201 -0.02   \n",
      "4         d_1_cat                  0.1581                  0.1577  0.00   \n",
      "5         d_1_dog                  0.1929                  0.1997 -0.02   \n",
      "6        d_1_pear                  0.2551                  0.2541  0.00   \n",
      "7     gender_cat1                  0.0714                  0.0671  0.02   \n",
      "8     gender_cat2                  0.0883                  0.0938 -0.02   \n",
      "9     gender_cat3                  0.1683                  0.1691 -0.00   \n",
      "10    gender_cat4                  0.0247                  0.0288 -0.03   \n",
      "11    gender_cat5                  0.1156                  0.1167 -0.00   \n",
      "12     gender_dog                  0.1776                  0.1756  0.01   \n",
      "13  gender_female                  0.2046                  0.1964  0.02   \n",
      "14    gender_male                  0.0263                  0.0249  0.01   \n",
      "15     gender_pig                  0.0736                  0.0793 -0.02   \n",
      "\n",
      "    Var Ratio  ks-p_val  ttest-p_val  \n",
      "0        0.99     0.406        0.722  \n",
      "1        0.99     0.736        0.516  \n",
      "2        0.99     0.904        0.736  \n",
      "3         NaN     0.998        0.229  \n",
      "4         NaN     1.000        0.948  \n",
      "5         NaN     1.000        0.391  \n",
      "6         NaN     1.000        0.916  \n",
      "7         NaN     1.000        0.396  \n",
      "8         NaN     1.000        0.336  \n",
      "9         NaN     1.000        0.921  \n",
      "10        NaN     1.000        0.211  \n",
      "11        NaN     1.000        0.863  \n",
      "12        NaN     1.000        0.798  \n",
      "13        NaN     0.995        0.305  \n",
      "14        NaN     1.000        0.656  \n",
      "15        NaN     1.000        0.278  )\n",
      "    y       ate         p_val\n",
      "0   y  0.512805  3.793976e-29\n",
      "1  y2  0.104296  1.281648e-07\n"
     ]
    }
   ],
   "source": [
    "# STEP 0: choose the propensity-score (p-score) model; here a single estimator is passed via `model`\n",
    "# (pass `model_list = [...]` instead to let CausalMatch pick the candidate with the best F1 score)\n",
    "ps_model = LogisticRegression(C=1e6)\n",
    "\n",
    "# STEP 1: initialize matching object\n",
    "match_obj = matching(data = df,\n",
    "                     T = T,\n",
    "                     X = X,\n",
    "                     y = y, # the outcome variable(s) must be given to use the ATE function\n",
    "                     id = id)\n",
    "\n",
    "# STEP 2: propensity score matching\n",
    "match_obj.psm(n_neighbors = 1,\n",
    "              model = ps_model, # single p-score model\n",
    "              trim_percentage = 0,\n",
    "              caliper = 1,\n",
    "              test_size = 0.2) # share of the sample held out as the test set in the train-test split\n",
    "\n",
    "# STEP 3: balance check after propensity score matching\n",
    "print(match_obj.balance_check(include_discrete = True))\n",
    "\n",
    "# STEP 4: obtain the average treatment effect (one row per outcome in y)\n",
    "print(match_obj.ate())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25cb9aba",
   "metadata": {},
   "source": [
    "## 3. Coarsened Exact Matching (CEM)\n",
    "### 3.1 Simple CEM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "03d98668",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of matched obs (9768, 15) number of total obs  (10000, 15)\n",
      "(   Covariates  Mean Treated post-match  Mean Control post-match   SMD  \\\n",
      "0         c_1                   0.4999                   0.4983  0.01   \n",
      "1     d_1_bee                   0.1114                   0.1194 -0.03   \n",
      "2     d_1_cat                   0.1546                   0.1570 -0.01   \n",
      "3     d_1_dog                   0.1929                   0.2002 -0.02   \n",
      "4    d_1_pear                   0.2588                   0.2545  0.01   \n",
      "5     d_3_1.0                   0.2039                   0.1882  0.04   \n",
      "6     d_3_2.0                   0.0516                   0.0516  0.00   \n",
      "7     d_3_3.0                   0.1790                   0.1804 -0.00   \n",
      "8     d_3_4.0                   0.0704                   0.0713 -0.00   \n",
      "9     d_3_5.0                   0.0692                   0.0676  0.01   \n",
      "10    d_3_6.0                   0.0921                   0.0938 -0.01   \n",
      "11    d_3_7.0                   0.1824                   0.1835 -0.00   \n",
      "12    d_3_8.0                   0.0221                   0.0211  0.01   \n",
      "13    d_3_9.0                   0.1063                   0.1140 -0.02   \n",
      "\n",
      "    Var Ratio  ks-p_val  ttest-p_val  \n",
      "0        0.99     0.445        0.789  \n",
      "1         NaN     0.997        0.217  \n",
      "2         NaN     1.000        0.738  \n",
      "3         NaN     0.999        0.359  \n",
      "4         NaN     1.000        0.627  \n",
      "5         NaN     0.573        0.050  \n",
      "6         NaN     1.000        1.000  \n",
      "7         NaN     1.000        0.854  \n",
      "8         NaN     1.000        0.875  \n",
      "9         NaN     1.000        0.748  \n",
      "10        NaN     1.000        0.780  \n",
      "11        NaN     1.000        0.896  \n",
      "12        NaN     1.000        0.728  \n",
      "13        NaN     0.998        0.219  ,    Covariates  Mean Treated pre-match  Mean Control pre-match   SMD  \\\n",
      "0         c_1                  0.5004                  0.4984  0.01   \n",
      "1     d_1_bee                  0.1124                  0.1201 -0.02   \n",
      "2     d_1_cat                  0.1581                  0.1577  0.00   \n",
      "3     d_1_dog                  0.1929                  0.1997 -0.02   \n",
      "4    d_1_pear                  0.2551                  0.2541  0.00   \n",
      "5     d_3_1.0                  0.2017                  0.1874  0.04   \n",
      "6     d_3_2.0                  0.0520                  0.0514  0.00   \n",
      "7     d_3_3.0                  0.1787                  0.1797 -0.00   \n",
      "8     d_3_4.0                  0.0691                  0.0710 -0.01   \n",
      "9     d_3_5.0                  0.0687                  0.0673  0.01   \n",
      "10    d_3_6.0                  0.0932                  0.0934 -0.00   \n",
      "11    d_3_7.0                  0.1809                  0.1827 -0.00   \n",
      "12    d_3_8.0                  0.0253                  0.0226  0.02   \n",
      "13    d_3_9.0                  0.1052                  0.1136 -0.03   \n",
      "\n",
      "    Var Ratio  ks-p_val  ttest-p_val  \n",
      "0        0.99     0.406        0.722  \n",
      "1         NaN     0.998        0.229  \n",
      "2         NaN     1.000        0.948  \n",
      "3         NaN     1.000        0.391  \n",
      "4         NaN     1.000        0.916  \n",
      "5         NaN     0.685        0.072  \n",
      "6         NaN     1.000        0.893  \n",
      "7         NaN     1.000        0.901  \n",
      "8         NaN     1.000        0.707  \n",
      "9         NaN     1.000        0.787  \n",
      "10        NaN     1.000        0.970  \n",
      "11        NaN     1.000        0.810  \n",
      "12        NaN     1.000        0.383  \n",
      "13        NaN     0.994        0.176  )\n",
      "   y       ate         p_val\n",
      "0  y  0.500467  1.832978e-26\n"
     ]
    }
   ],
   "source": [
    "# STEP 1: initialize matching object\n",
    "match_obj_cem = matching(data = df,\n",
    "                         y = ['y'],\n",
    "                         T = 'treatment',\n",
    "                         X = ['c_1','d_1','d_3'],\n",
    "                         id = 'user_id')\n",
    "\n",
    "# STEP 2: coarsened exact matching\n",
    "match_obj_cem.cem(n_bins = 10, # number of bins used to coarsen continuous variables (percentile cut via pd.qcut)\n",
    "                  k2k = True)  # k2k: keep the number of treated units equal to the number of controls; if False, use weighted least squares to obtain the ATE\n",
    "\n",
    "# STEP 3: balance check after coarsened exact matching\n",
    "print(match_obj_cem.balance_check(include_discrete=True))\n",
    "\n",
    "# STEP 4: obtain the average treatment effect\n",
    "print(match_obj_cem.ate())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39c780f5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "f8dd7111",
   "metadata": {},
   "source": [
    "### 3.2 CEM with customized bin\n",
    "If a continuous feature is highly skewed, use custom breakpoints instead of percentile cuts. Discrete features can likewise be merged into coarser groups."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "94bab763",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of matched obs (9806, 15) number of total obs  (10000, 15)\n",
      "(   Covariates  Mean Treated post-match  Mean Control post-match   SMD  \\\n",
      "0         c_1                   0.5002                   0.4984  0.01   \n",
      "1     d_1_bee                   0.1103                   0.1201 -0.03   \n",
      "2     d_1_cat                   0.1575                   0.1577 -0.00   \n",
      "3     d_1_dog                   0.1934                   0.1997 -0.02   \n",
      "4    d_1_pear                   0.2549                   0.2541  0.00   \n",
      "5     d_3_1.0                   0.2025                   0.1874  0.04   \n",
      "6     d_3_2.0                   0.0512                   0.0514 -0.00   \n",
      "7     d_3_3.0                   0.1783                   0.1797 -0.00   \n",
      "8     d_3_4.0                   0.0706                   0.0710 -0.00   \n",
      "9     d_3_5.0                   0.0687                   0.0673  0.01   \n",
      "10    d_3_6.0                   0.0930                   0.0934 -0.00   \n",
      "11    d_3_7.0                   0.1813                   0.1827 -0.00   \n",
      "12    d_3_8.0                   0.0247                   0.0226  0.01   \n",
      "13    d_3_9.0                   0.1050                   0.1136 -0.03   \n",
      "\n",
      "    Var Ratio  ks-p_val  ttest-p_val  \n",
      "0        0.99     0.433        0.750  \n",
      "1         NaN     0.971        0.129  \n",
      "2         NaN     1.000        0.978  \n",
      "3         NaN     1.000        0.431  \n",
      "4         NaN     1.000        0.926  \n",
      "5         NaN     0.626        0.059  \n",
      "6         NaN     1.000        0.963  \n",
      "7         NaN     1.000        0.854  \n",
      "8         NaN     1.000        0.937  \n",
      "9         NaN     1.000        0.779  \n",
      "10        NaN     1.000        0.945  \n",
      "11        NaN     1.000        0.855  \n",
      "12        NaN     1.000        0.506  \n",
      "13        NaN     0.993        0.174  ,    Covariates  Mean Treated pre-match  Mean Control pre-match   SMD  \\\n",
      "0         c_1                  0.5004                  0.4984  0.01   \n",
      "1     d_1_bee                  0.1124                  0.1201 -0.02   \n",
      "2     d_1_cat                  0.1581                  0.1577  0.00   \n",
      "3     d_1_dog                  0.1929                  0.1997 -0.02   \n",
      "4    d_1_pear                  0.2551                  0.2541  0.00   \n",
      "5     d_3_1.0                  0.2017                  0.1874  0.04   \n",
      "6     d_3_2.0                  0.0520                  0.0514  0.00   \n",
      "7     d_3_3.0                  0.1787                  0.1797 -0.00   \n",
      "8     d_3_4.0                  0.0691                  0.0710 -0.01   \n",
      "9     d_3_5.0                  0.0687                  0.0673  0.01   \n",
      "10    d_3_6.0                  0.0932                  0.0934 -0.00   \n",
      "11    d_3_7.0                  0.1809                  0.1827 -0.00   \n",
      "12    d_3_8.0                  0.0253                  0.0226  0.02   \n",
      "13    d_3_9.0                  0.1052                  0.1136 -0.03   \n",
      "\n",
      "    Var Ratio  ks-p_val  ttest-p_val  \n",
      "0        0.99     0.406        0.722  \n",
      "1         NaN     0.998        0.229  \n",
      "2         NaN     1.000        0.948  \n",
      "3         NaN     1.000        0.391  \n",
      "4         NaN     1.000        0.916  \n",
      "5         NaN     0.685        0.072  \n",
      "6         NaN     1.000        0.893  \n",
      "7         NaN     1.000        0.901  \n",
      "8         NaN     1.000        0.707  \n",
      "9         NaN     1.000        0.787  \n",
      "10        NaN     1.000        0.970  \n",
      "11        NaN     1.000        0.810  \n",
      "12        NaN     1.000        0.383  \n",
      "13        NaN     0.994        0.176  )\n",
      "0.5032208003719669 7.70727152999287e-27\n",
      "ATE using OLS:     y       ate         p_val\n",
      "0  y  0.503221  7.707272e-27\n",
      "ATE using WLS:     y       ate         p_val\n",
      "0  y  0.498841  2.374094e-26\n"
     ]
    }
   ],
   "source": [
    "# STEP 1: initialize matching object\n",
    "match_obj_cem = matching(data = df,\n",
    "                         y = ['y'],\n",
    "                         T = 'treatment',\n",
    "                         X = ['c_1','d_1','d_3'],\n",
    "                         id = 'user_id')\n",
    "\n",
    "# STEP 2: coarsened exact matching\n",
    "match_obj_cem.cem(n_bins = 10,\n",
    "                  # continuous feature c_1 is split into 5 bins by customized breakpoints:\n",
    "                  # -> [-inf,-1), [-1, 0.3), [0.3, 0.6), [0.6, 2), [2, inf]\n",
    "                  break_points = {'c_1': [-1, 0.3, 0.6, 2]},\n",
    "                  # discrete/string features are merged into coarser groups:\n",
    "                  # -- 1. d_1 has 5 values, grouped into 3\n",
    "                  # -- 2. d_3 has 10 values, grouped into 3\n",
    "                  cluster_criteria = {'d_1': [['apple','pear'],['cat','dog'],['bee']],\n",
    "                                      'd_3': [['0.0','1.0','2.0'],\n",
    "                                             ['3.0','4.0','5.0'],\n",
    "                                             ['6.0','7.0','8.0','9.0']]},\n",
    "                  k2k = True)\n",
    "\n",
    "# STEP 3: balance check after coarsened exact matching\n",
    "print(match_obj_cem.balance_check(include_discrete=True))\n",
    "\n",
    "# Manual check: OLS of y on treatment in the matched sample stored in match_obj_cem.df_out_final.\n",
    "# New names are used so the covariate list `X` (and `y`) relied on by earlier cells is not overwritten.\n",
    "y_cem = match_obj_cem.df_out_final['y']\n",
    "X_cem = sm.add_constant(match_obj_cem.df_out_final['treatment'])\n",
    "results_cem = sm.OLS(y_cem, X_cem).fit()\n",
    "print(results_cem.params['treatment'], results_cem.pvalues['treatment'])\n",
    "\n",
    "# STEP 4: obtain the average treatment effect, unweighted (OLS) and weighted (WLS, use_weight=True)\n",
    "print('ATE using OLS: ',match_obj_cem.ate())\n",
    "print('ATE using WLS: ',match_obj_cem.ate(use_weight=True))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
